"""
© 2021 This work is licensed under a CC-BY-NC-SA license.
Title: *"Behavioral cloning in recurrent spiking networks: A comprehensive framework"*
**Authors:** Anonymous
"""

"""

This script reproduces panels E-F of Fig. 5.
The panels are saved in the folder "figures".

"""


import matplotlib.pyplot as plt
import numpy as np
from numpy import savetxt,loadtxt
import pylab as py
import os.path
import os

# Reproduces Figure 5E-F of the main text.
# `folder` holds the pre-saved .npy results that are loaded below;
# `folder_save` is where the resulting figure is written.
folder, folder_save = "RewardsErrors_presaved", "figures"

def moving_average(x, w):
    """Return the running mean of ``x`` over windows of length ``w``.

    Only fully overlapping windows are kept (``'valid'`` convolution), so for
    ``w <= len(x)`` the result has ``len(x) - w + 1`` entries. With ``w == 1``
    the input values are returned unchanged (as floats).
    """
    window_sums = np.convolve(x, np.ones(w), mode='valid')
    return window_sums / w


# Analysis parameters.
# n_iter_out: presumably the length of each saved reward trace (it fixes the
#             column count of the arrays the traces are stacked into).
# number_of_reps: number of repetitions loaded per tau value.
n_iter_out, number_of_reps = 8, 10
sigma_teach = 3.0   # appears in the result file names as "_st3.0"
smooth_window = 1   # moving-average window; 1 leaves the traces unsmoothed

# Ten tau values, log-uniformly spaced between 10**-1.5 and 10**1.5.
# Their str() forms are part of the result file names, so they must be
# computed exactly this way.
log_tau_vals = np.linspace(-1.5, 1.5, 10)
tau_vals = 10 ** log_tau_vals

# Per-tau summaries of the reward traces, filled by the loop below
# (entry k corresponds to tau_vals[k]).
max_avg = []    # max over iterations of the across-rep median trace (not plotted below)
max_avg_3 = []  # median over reps of each rep's peak reward (lower subplot)
max_std = []    # std over reps of each rep's peak reward (lower subplot error bars)

# One across-rep median reward trace per tau value, stacked row by row.
# Float dtype: the smoothed traces are always float, and concatenation
# upcast the former int placeholder anyway.
all_rewards = np.zeros((0, n_iter_out - smooth_window + 1))


# Not referenced by the code in this script; kept in case other code relies on them.
n_reps = 10
rank = 500

for k in range(len(tau_vals)):

    # Rows = repetitions, columns = smoothed iterations. Each saved trace is
    # assumed to hold n_iter_out entries (TODO confirm); otherwise the row
    # concatenation below fails with a shape mismatch.
    avg_reward = np.zeros((0, n_iter_out - smooth_window + 1))
    max_rewards = []

    for n_rep in range(0, number_of_reps):
        path = "avg_reward" + "_st" + str(sigma_teach) + "_" + str(n_rep) + "_not_clumped_tau" + str(tau_vals[k]) + ".npy"
        arr = np.load(os.path.join(folder, path))

        # Promote the 1-D smoothed trace to a (1, n) row with a plain ndarray
        # instead of the discouraged np.matrix class.
        arr = np.atleast_2d(moving_average(arr, smooth_window))
        avg_reward = np.concatenate((avg_reward, arr))
        max_rewards.append(np.max(arr))

    # keepdims=True keeps the (1, n) row shape (previously obtained implicitly
    # through np.matrix) so the median trace stacks onto all_rewards.
    median_trace = np.median(avg_reward, axis=0, keepdims=True)
    all_rewards = np.concatenate((all_rewards, median_trace))

    max_avg.append(np.max(median_trace))
    max_avg_3.append(np.median(max_rewards))
    max_std.append(np.std(max_rewards))

# Mean and (population) standard deviation of the saved DS values across
# repetitions, one entry per tau value (entry k <-> tau_vals[k]).
DS_mean = []
DS_std = []

for tau in tau_vals:
    # "%s" applies str() exactly like the original string concatenation, so
    # the generated file names are unchanged.
    DS_rep = [
        np.load(os.path.join(
            folder,
            "DS_st%s_%d_not_clumped_tau%s.npy" % (sigma_teach, rep, tau),
        ))
        for rep in range(number_of_reps)
    ]
    DS_mean.append(np.mean(DS_rep))
    DS_std.append(np.std(DS_rep))


fig = plt.figure()
fig.set_size_inches(3, 5)

# Upper subplot: mean DS (labelled Delta S) versus tau on log-log axes.
# Error bars are std / sqrt(number_of_reps), i.e. a standard error computed
# from the ddof=0 standard deviation.
plt.subplot(211)
plt.errorbar(tau_vals, DS_mean, np.asarray(DS_std) / np.sqrt(number_of_reps))
plt.xscale('log')
plt.yscale('log')
# Raw string: '\D' is not a valid escape sequence (SyntaxWarning on Python 3.12+).
plt.ylabel(r'$\Delta S$')

# Lower subplot: median over repetitions of each repetition's peak reward,
# versus tau (log x axis), with the same standard-error-style error bars.
plt.subplot(212)

plt.plot(tau_vals, max_avg_3, 'o')
plt.errorbar(tau_vals, max_avg_3, np.asarray(max_std) / np.sqrt(number_of_reps))
plt.xscale('log')
plt.ylabel('reward')
plt.xlabel('$\u03C4_*$')

plt.tight_layout()

# savefig does not create missing directories; make sure the output folder exists.
os.makedirs(folder_save, exist_ok=True)
path = "Fig5.eps"
plt.savefig(os.path.join(folder_save, path), format='eps')
plt.show()
